from tensorflow.python import keras

class CdssmBase():
    """CDSSM-style siamese text-matching model.

    Two inputs share one embedding + Conv1D/MaxPool tower; each tower is
    flattened into a sentence vector and the training target is the cosine
    similarity (normalized Dot) between the two vectors.

    model_conf keys read here: "w2c_len", "emb_model", "MAX_LEN", and the
    optional "filter_sizes" (defaults to DEFAULT_FILTER_SIZES).
    """

    # Conv filter widths wired into each tower.  The original code created
    # 8 conv/pool pairs but only used the first 6; one list is now the
    # single source of truth so creation and usage cannot drift apart.
    DEFAULT_FILTER_SIZES = (1, 2, 3, 4, 5, 6)

    def __init__(self, model_conf, train_conf=None):
        """Build and compile the model.

        model_conf: dict-like config (see class docstring for keys).
        train_conf: optional dict with "batch_size"/"epochs"/"verbose";
            defaults are applied per-key in fit().  A None default avoids
            the shared-mutable-default-argument pitfall of the original.
        """
        self.model_conf = model_conf
        if train_conf is None:
            train_conf = {"batch_size": 64, "epochs": 3, "verbose": 1}
        self.train_conf = train_conf
        self.__prepare_model__()
        self.__build_structure__()

    def __prepare_model__(self):
        """Create the shared (siamese) layers used by both input towers."""
        emb_model = self.model_conf["emb_model"]
        self.embedding = keras.layers.Embedding(
            output_dim=self.model_conf["w2c_len"],
            input_dim=len(emb_model.embedding_weights),
            weights=[emb_model.get_np_weights()],
            input_length=self.model_conf["MAX_LEN"],
            trainable=True,
        )

        # One conv/pool pair per filter width.  With "valid" padding the
        # conv output length is MAX_LEN - filter_size + 1, so the pool
        # collapses each branch to a single 100-channel vector.
        filter_sizes = self.model_conf.get("filter_sizes", self.DEFAULT_FILTER_SIZES)
        self.conv1ds = []
        self.max_pools = []
        for filter_size in filter_sizes:
            self.conv1ds.append(keras.layers.Conv1D(100, filter_size, padding="valid"))
            self.max_pools.append(keras.layers.MaxPool1D(
                pool_size=self.model_conf["MAX_LEN"] - filter_size + 1))

        self.leakyRelu = keras.layers.LeakyReLU(alpha=0.2)
        self.concatenate = keras.layers.Concatenate(axis=1)
        # Named so the trained sentence-embedding output can be looked up
        # by name in fit().
        self.flatten = keras.layers.Flatten(name="embedding_result")
        # normalize=True makes the dot product a cosine similarity.
        self.dot = keras.layers.Dot(axes=-1, normalize=True)

    # Backward-compatible alias for the original (misspelled) method name.
    __prepare_moel__ = __prepare_model__

    def __build_model__(self, input_info):
        """Apply the shared tower to one input; return its flat embedding.

        Iterates every conv/pool pair that was created (the original
        hard-coded range(6) against 8 created pairs, leaving dead layers).
        """
        emb = self.embedding(input_info)
        branches = []
        for conv, pool in zip(self.conv1ds, self.max_pools):
            branches.append(pool(self.leakyRelu(conv(emb))))
        return self.flatten(self.concatenate(branches))

    def __build_structure__(self):
        """Wire both towers into a compiled cosine-similarity model."""
        max_len = self.model_conf["MAX_LEN"]
        # shape must be a tuple: the original passed (MAX_LEN) which is a
        # bare int, not a 1-tuple.
        inputs_1 = keras.layers.Input(shape=(max_len,), name="inputs_1")
        inputs_2 = keras.layers.Input(shape=(max_len,), name="inputs_2")

        emb_input1 = self.__build_model__(inputs_1)
        emb_input2 = self.__build_model__(inputs_2)

        # Single cosine score used directly: the original wrapped it in
        # Concatenate over a one-element list, which Keras rejects
        # (Concatenate requires at least 2 inputs).
        cos_score = self.dot([emb_input1, emb_input2])

        self.model = keras.Model(inputs=[inputs_1, inputs_2], outputs=cos_score)
        self.model.summary()
        self.model.compile(loss="mse", optimizer="adam", metrics=["mse", "accuracy"])

    def fit(self, x_train, y_train, x_test, y_test):
        """Train the model and build per-tower embedding sub-models.

        Returns the keras History object from model.fit().
        """
        history = self.model.fit(
            x_train, y_train,
            batch_size=self.train_conf.get("batch_size", 64),
            epochs=self.train_conf.get("epochs", 3),
            validation_data=(x_test, y_test),
            verbose=self.train_conf.get("verbose", 1))

        # The flatten layer is shared by both towers, so it has two inbound
        # nodes and `.output` is ambiguous (raises in Keras); select each
        # node explicitly.  Node 0 belongs to inputs_1 (built first),
        # node 1 to inputs_2.
        flatten = self.model.get_layer("embedding_result")
        self.model_vec_1 = keras.models.Model(
            inputs=self.model.get_layer("inputs_1").input,
            outputs=flatten.get_output_at(0))
        self.model_vec_2 = keras.models.Model(
            inputs=self.model.get_layer("inputs_2").input,
            outputs=flatten.get_output_at(1))
        return history

    def evaluate(self, x_test, y_test):
        """Evaluate the compiled model; returns loss/metric values."""
        return self.model.evaluate(x_test, y_test)

    def predict(self, sentences):
        """Predict cosine-similarity scores for paired inputs."""
        return self.model.predict(sentences)

    def predict_vec(self, sentences):
        """Return sentence embeddings from each tower.

        The towers share all weights, so both results are numerically
        identical for the same input; the original called model_vec_1
        twice, which this preserves in effect while using each sub-model.
        Requires fit() to have been called first.
        """
        return self.model_vec_1.predict(sentences), self.model_vec_2.predict(sentences)

    def save(self, path):
        """Persist the full training model to `path` (format per Keras)."""
        if self.model:
            self.model.save(path)

    def load(self, path):
        """Load a previously saved model.

        The original called keras.load_model, which does not exist; the
        loader lives in keras.models.
        """
        self.model = keras.models.load_model(path)